'''
Pseudo-code intended to help readers understand the paper.
The full source code is planned for public release.
'''
import torch
import torch.nn as nn
import torch.nn.functional as F
from modules.model import BVAETTS
import hparams
from text import *
from utils.utils import *


    
def main():
    """Train BVAE-TTS (pseudo-code training loop).

    Builds the train/validation loaders and the model from ``hparams``, then
    optimises until ``hparams.train_steps`` mini-batches have been processed.
    Validation runs every ``hparams.iters_per_validation`` mini-batches.

    The total loss is ``recon + alpha * kl + duration + align``. The KL weight
    ``alpha`` ramps linearly from 0 to 1 over the first ``kl_warmup_steps``
    mini-batches.

    Gradients are accumulated over ``hparams.accumulation`` mini-batches per
    optimizer update. This matches the ``iteration // hparams.accumulation``
    step index that is handed to ``lr_scheduling``.

    Raises:
        RuntimeError: if ``train_loader`` yields no batches, which would
            otherwise loop forever.
    """
    # Local import so the file's top-level import block stays unchanged.
    from torch.utils.tensorboard import SummaryWriter

    kl_warmup_steps = 60000  # mini-batches over which the KL weight ramps 0 -> 1
    accumulation = hparams.accumulation  # mini-batches per optimizer update

    train_loader, val_loader, _collate_fn = prepare_dataloaders(hparams)
    model = BVAETTS(hparams).cuda()
    optimizer = torch.optim.Adamax(model.parameters(), lr=hparams.lr)
    # NOTE(review): logs go to SummaryWriter's default directory (./runs/...).
    # Point log_dir at the project's output directory if hparams defines one.
    writer = SummaryWriter()

    iteration = 0
    model.train()
    try:
        while iteration < hparams.train_steps:
            consumed_batch = False
            for batch in train_loader:
                consumed_batch = True
                text_padded, text_lengths, mel_padded, mel_lengths = [x.cuda() for x in batch]
                recon_loss, kl_loss, duration_loss, align_loss = model(
                    text_padded, mel_padded, text_lengths, mel_lengths)

                alpha = min(1, iteration / kl_warmup_steps)
                loss = recon_loss + alpha * kl_loss + duration_loss + align_loss
                # Scale so the accumulated gradient is the mean over the group.
                (loss / accumulation).backward()
                iteration += 1

                # Update only once per `accumulation` mini-batches, so that
                # `iteration // accumulation` counts optimizer updates.
                if iteration % accumulation == 0:
                    lr_scheduling(optimizer, iteration // accumulation)
                    nn.utils.clip_grad_norm_(model.parameters(), hparams.grad_clip_thresh)
                    optimizer.step()
                    model.zero_grad()

                if iteration % hparams.iters_per_validation == 0:
                    validate(model, val_loader, iteration, writer)
                    # validate() may switch the model to eval mode; restore it.
                    model.train()

                if iteration >= hparams.train_steps:
                    break

            if not consumed_batch:
                raise RuntimeError("train_loader yielded no batches")
    finally:
        writer.close()
                
                
if __name__ == "__main__":
    # Start training only when run as a script, not when imported.
    main()